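'''
Training script for the model: loads the training features and labels,
fits an MLPriceModel ('MLP'), and exports the trained model to ONNX.
'''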
from pickle import load
import time

import numpy as np

from ml_models import MLPriceModel
from constant import CHECK_POINT,DatasetLabelPath,DatasetTrainPath,LabelScalerPath

# Load training features and targets; both are cast to float32 before fitting.
data = np.load(DatasetTrainPath).astype(np.float32)
print('data.shape', data.shape)
label = np.load(DatasetLabelPath).astype(np.float32)

# Fail fast with a clear message instead of an opaque error inside fit().
if data.shape[0] != label.shape[0]:
    raise ValueError(
        f'sample count mismatch: data has {data.shape[0]} rows, '
        f'label has {label.shape[0]}'
    )

# NOTE(review): label_scaler is loaded but not used in this script; keep it
# only if something downstream relies on it. pickle can execute arbitrary
# code on load, so LabelScalerPath must point to a trusted file.
with open(LabelScalerPath, 'rb') as f:
    label_scaler = load(f)

model = MLPriceModel('MLP')
# perf_counter is monotonic, so the measured duration is unaffected by
# system clock adjustments (unlike time.time()).
train_at = time.perf_counter()
model.fit(data, label)
print('train consume', time.perf_counter() - train_at)

# data.shape[-1] is passed as the export input size — presumably the number
# of features per sample; confirm against MLPriceModel.save_to_onnx.
model.save_to_onnx(data.shape[-1], f'{CHECK_POINT}{model.get_name()}.onnx')